{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from argparse import Namespace\n",
    "import collections\n",
    "import nltk.data\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import string\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Namespace(\n",
    "    raw_dataset_txt=\"data/books/frankenstein.txt\",\n",
    "    window_size=5,\n",
    "    train_proportion=0.7,\n",
    "    val_proportion=0.15,\n",
    "    test_proportion=0.15,\n",
    "    output_munged_csv=\"data/books/frankenstein_with_splits.csv\",\n",
    "    seed=1337\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the raw text book into sentences\n",
    "tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')\n",
    "with open(args.raw_dataset_txt) as fp:\n",
    "    book = fp.read()\n",
    "sentences = tokenizer.tokenize(book)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3427 sentences\n",
      "Sample: No incidents have hitherto befallen us that would make a figure in a\n",
      "letter.\n"
     ]
    }
   ],
   "source": [
    "print (len(sentences), \"sentences\")\n",
    "print (\"Sample:\", sentences[100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean sentences\n",
    "def preprocess_text(text):\n",
    "    text = ' '.join(word.lower() for word in text.split(\" \"))\n",
    "    text = re.sub(r\"([.,!?])\", r\" \\1 \", text)\n",
    "    text = re.sub(r\"[^a-zA-Z.,!?]+\", r\" \", text)\n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "cleaned_sentences = [preprocess_text(sentence) for sentence in sentences]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global vars\n",
    "MASK_TOKEN = \"<MASK>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1fd40a8ddcd54427898109cfd74b4798",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=3427), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3619638283c745ba96319ddb07156b6d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=90698), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Create windows\n",
    "flatten = lambda outer_list: [item for inner_list in outer_list for item in inner_list]\n",
    "windows = flatten([list(nltk.ngrams([MASK_TOKEN] * args.window_size + sentence.split(' ') + \\\n",
    "    [MASK_TOKEN] * args.window_size, args.window_size * 2 + 1)) \\\n",
    "    for sentence in tqdm_notebook(cleaned_sentences)])\n",
    "\n",
    "# Create cbow data\n",
    "data = []\n",
    "for window in tqdm_notebook(windows):\n",
    "    target_token = window[args.window_size]\n",
    "    context = []\n",
    "    for i, token in enumerate(window):\n",
    "        if token == MASK_TOKEN or i == args.window_size:\n",
    "            continue\n",
    "        else:\n",
    "            context.append(token)\n",
    "    data.append([' '.join(token for token in context), target_token])\n",
    "    \n",
    "            \n",
    "# Convert to dataframe\n",
    "cbow_data = pd.DataFrame(data, columns=[\"context\", \"target\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create split data\n",
    "n = len(cbow_data)\n",
    "def get_split(row_num):\n",
    "    if row_num <= n*args.train_proportion:\n",
    "        return 'train'\n",
    "    elif (row_num > n*args.train_proportion) and (row_num <= n*args.train_proportion + n*args.val_proportion):\n",
    "        return 'val'\n",
    "    else:\n",
    "        return 'test'\n",
    "cbow_data['split']= cbow_data.apply(lambda row: get_split(row.name), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>context</th>\n",
       "      <th>target</th>\n",
       "      <th>split</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>, or the modern prometheus</td>\n",
       "      <td>frankenstein</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>frankenstein or the modern prometheus by</td>\n",
       "      <td>,</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>frankenstein , the modern prometheus by mary</td>\n",
       "      <td>or</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>frankenstein , or modern prometheus by mary wo...</td>\n",
       "      <td>the</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>frankenstein , or the prometheus by mary wolls...</td>\n",
       "      <td>modern</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             context        target  split\n",
       "0                         , or the modern prometheus  frankenstein  train\n",
       "1           frankenstein or the modern prometheus by             ,  train\n",
       "2       frankenstein , the modern prometheus by mary            or  train\n",
       "3  frankenstein , or modern prometheus by mary wo...           the  train\n",
       "4  frankenstein , or the prometheus by mary wolls...        modern  train"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cbow_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write split data to file\n",
    "cbow_data.to_csv(args.output_munged_csv, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlpbook",
   "language": "python",
   "name": "nlpbook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "12px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": "5",
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
